/*
 *   Copyright 2012, Thomas Kerber
 *
 *   Licensed under the Apache License, Version 2.0 (the "License");
 *   you may not use this file except in compliance with the License.
 *   You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 *   Unless required by applicable law or agreed to in writing, software
 *   distributed under the License is distributed on an "AS IS" BASIS,
 *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *   See the License for the specific language governing permissions and
 *   limitations under the License.
 */
package milk.jpatch.code;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.bcel.classfile.ConstantCP;
import org.apache.bcel.classfile.ConstantNameAndType;
import org.apache.bcel.classfile.ConstantPool;
import org.apache.bcel.generic.BranchHandle;
import org.apache.bcel.generic.BranchInstruction;
import org.apache.bcel.generic.CPInstruction;
import org.apache.bcel.generic.IfInstruction;
import org.apache.bcel.generic.Instruction;
import org.apache.bcel.generic.InstructionHandle;
import org.apache.bcel.generic.InstructionList;
import org.apache.bcel.generic.LocalVariableInstruction;

import static org.apache.bcel.Constants.*;

/**
 * Defines methods for working with method bytecode.
 * @author Thomas Kerber
 * @version 1.0.1
 */
public class CodeUtil{
    
    /**
     * 
     * @param instr An method call instruction.
     * @param cpool The constant pool.
     * @return The modification of the size of the stack by the instruction.
     * @throws ClassCastException if the instruction doesn't call a valid
     *     method.
     */
    public static int getMethodStackModif(CPInstruction instr,
            ConstantPool cpool){
        return getMethodStackModif(instr.getIndex(), cpool);
    }
    
    /**
     * 
     * @param methodIndex The index of the method in the cpool.
     * @param cpool The constant pool.
     * @return The modification of the size of the stack by the instruction.
     */
    public static int getMethodStackModif(int methodIndex, ConstantPool cpool){
        int nameAndTypeIndex = ((ConstantCP)cpool.getConstant(methodIndex)).
                getNameAndTypeIndex();
        String sig = ((ConstantNameAndType)cpool.getConstant(nameAndTypeIndex)).
                getSignature(cpool);
        
        int pos = 1; // Open Brackets skipped.
        // Note: for the purposes of this method, the return type counts as a
        // "negative" parameter.
        int params = 0;
        
        while(true){
            // Breaks when close parens are found.
            if(sig.charAt(pos) == ')'){
                // Return type void is the *only* type which doesn't increment
                // The stack.
                if(sig.charAt(++pos) != 'V')
                    params--;
                break;
            }
            // Else another parameter was found.
            params++;
            // And it moves to the end of it.
            if(sig.charAt(pos) == 'L'){
                while(sig.charAt(pos++) != ';');
            }
            else{
                while(true){
                    char c = sig.charAt(pos++);
                    if(c == 'B' || c == 'C' || c == 'D' || c == 'F' ||
                            c == 'I'|| c == 'J' || c == 'S' || c == 'Z')
                        break;
                }
            }
        }
        
        return -params;
    }
    
    /**
     * 
     * @param i The instruction.
     * @param cpool The constant pool.
     * @return The modification of the size of the stack by the instruction.
     */
    public static int getInstructionStackModif(Instruction i,
            ConstantPool cpool){
        // Note: this does *not* use BCEL inbuilt method as that works with
        // longs and doubles taking up 2 stack spaces, which makes calculations
        // difficult/impossible.
        switch(i.getOpcode()){
            case AASTORE:
            case BASTORE:
            case CASTORE:
            case DASTORE:
            case FASTORE:
            case IASTORE:
            case LASTORE:
            case SASTORE:
                return -3;
            case IF_ACMPEQ: // Is a jump opcode.
            case IF_ACMPNE: // Is a jump opcode.
            case IF_ICMPEQ: // Is a jump opcode.
            case IF_ICMPNE: // Is a jump opcode.
            case IF_ICMPLT: // Is a jump opcode.
            case IF_ICMPGE: // Is a jump opcode.
            case IF_ICMPGT: // Is a jump opcode.
            case IF_ICMPLE: // Is a jump opcode.
            case POP2:
            case PUTFIELD:
                return -2;
            case AALOAD:
            case ARETURN: // Is a jump opcode.
            case ASTORE:
            case ASTORE_0:
            case ASTORE_1:
            case ASTORE_2:
            case ASTORE_3:
            case BALOAD:
            case CALOAD:
            case DADD:
            case DALOAD:
            case DCMPG:
            case DCMPL:
            case DDIV:
            case DMUL:
            case DREM:
            case DRETURN: // Is a jump opcode.
            case DSTORE:
            case DSTORE_0:
            case DSTORE_1:
            case DSTORE_2:
            case DSTORE_3:
            case DSUB:
            case FADD:
            case FALOAD:
            case FCMPG:
            case FCMPL:
            case FDIV:
            case FMUL:
            case FREM:
            case FRETURN: // Is a jump opcode.
            case FSTORE:
            case FSTORE_0:
            case FSTORE_1:
            case FSTORE_2:
            case FSTORE_3:
            case FSUB:
            case IADD:
            case IALOAD:
            case IAND:
            case IDIV:
            case IFEQ: // Is a jump opcode.
            case IFNE: // Is a jump opcode.
            case IFLT: // Is a jump opcode.
            case IFGE: // Is a jump opcode.
            case IFGT: // Is a jump opcode.
            case IFLE: // Is a jump opcode.
            case IFNONNULL: // Is a jump opcode.
            case IFNULL: // Is a jump opcode.
            case IMUL:
            case IOR:
            case IREM:
            case IRETURN: // Is a jump opcode.
            case ISHL:
            case ISHR:
            case ISTORE:
            case ISTORE_0:
            case ISTORE_1:
            case ISTORE_2:
            case ISTORE_3:
            case ISUB:
            case IUSHR:
            case IXOR:
            case LADD:
            case LALOAD:
            case LAND:
            case LCMP:
            case LDIV:
            case LMUL:
            case LOOKUPSWITCH: // Is a jump opcode.
            case LOR:
            case LREM:
            case LRETURN: // Is a jump opcode.
            case LSHL:
            case LSHR:
            case LSTORE:
            case LSTORE_0:
            case LSTORE_1:
            case LSTORE_2:
            case LSTORE_3:
            case LSUB:
            case LUSHR:
            case LXOR:
            case MONITORENTER:
            case MONITOREXIT:
            case POP:
            case PUTSTATIC:
            case SALOAD:
            case TABLESWITCH: // Is a jump opcode.
                return -1;
            case ANEWARRAY:
            case ARRAYLENGTH:
            case ATHROW: // Is a jump opcode.
            case CHECKCAST:
            case D2F:
            case D2I:
            case D2L:
            case DNEG:
            case F2D:
            case F2I:
            case F2L:
            case FNEG:
            case GETFIELD:
            case GOTO: // Is a jump opcode.
            case GOTO_W: // Is a jump opcode.
            case I2B:
            case I2C:
            case I2D:
            case I2F:
            case I2L:
            case I2S:
            case IINC:
            case INEG:
            case INSTANCEOF:
            case L2D:
            case L2F:
            case L2I:
            case LNEG:
            case NEWARRAY:
            case NOP:
            case RET: // Is a jump opcode.
            case RETURN: // Is a jump opcode.
            case SWAP:
                return 0;
            case ACONST_NULL:
            case ALOAD:
            case ALOAD_0:
            case ALOAD_1:
            case ALOAD_2:
            case ALOAD_3:
            case BIPUSH:
            case DCONST_0:
            case DCONST_1:
            case DLOAD:
            case DLOAD_0:
            case DLOAD_1:
            case DLOAD_2:
            case DLOAD_3:
            case DUP:
            case DUP_X1:
            case DUP_X2:
            case FCONST_0:
            case FCONST_1:
            case FCONST_2:
            case FLOAD:
            case FLOAD_0:
            case FLOAD_1:
            case FLOAD_2:
            case FLOAD_3:
            case GETSTATIC:
            case ICONST_M1:
            case ICONST_0:
            case ICONST_1:
            case ICONST_2:
            case ICONST_3:
            case ICONST_4:
            case ICONST_5:
            case ILOAD:
            case ILOAD_0:
            case ILOAD_1:
            case ILOAD_2:
            case ILOAD_3:
            case JSR: // Is a jump opcode.
            case JSR_W: // Is a jump opcode.
            case LCONST_0:
            case LCONST_1:
            case LDC:
            case LDC_W:
            case LDC2_W:
            case LLOAD:
            case LLOAD_0:
            case LLOAD_1:
            case LLOAD_2:
            case LLOAD_3:
            case NEW:
            case SIPUSH:
                return 1;
            case DUP2:
            case DUP2_X1:
            case DUP2_X2:
                return 2;
            case INVOKEDYNAMIC:
                // TODO: maybe add support sometime.
                throw new IllegalArgumentException("Invokedynamic is NOT " +
                        "SUPPORTED for patching.");
            case INVOKEINTERFACE:
            case INVOKESPECIAL:
            case INVOKEVIRTUAL:
                return getMethodStackModif((CPInstruction)i, cpool) - 1;
            case INVOKESTATIC:
                return getMethodStackModif((CPInstruction)i, cpool);
            case MULTIANEWARRAY:
                return 1 - ((org.apache.bcel.generic.MULTIANEWARRAY)i).
                        getDimensions();
            // Fortunately: WIDE is not required, as BCEL's Instruction List
            // filters it out.
            default:
                throw new IllegalArgumentException("Uknown opcode: " +
                        i.getOpcode());
        }
    }
    
    /**
     * 
     * @param i An instruction.
     * @return Whether or not the instruction is a jump instruction, i.e.
     *     the code might not continue at the following instruction.
     */
    public static boolean isJump(Instruction i){
        return i instanceof BranchInstruction ||
                i instanceof org.apache.bcel.generic.RET;
    }
    
    /**
     * Indicates a possible execution path through an Instruction list.
     * 
     * THIS CLASS IS DESIGNED ONLY AS A HELPER FOR getTopLevelInstructions()!
     * IT DOES NOT GENERATE ALL POSSIBLE PATHS, AND PRUNES OFF CERTAIN PARTS!
     * @author Thomas Kerber
     * @version 1.0.0
     */
    private static class InstructionPath implements Cloneable{
        
        /**
         * The instruction handles which have been visited.
         * 
         * NOTE: This must be a field, even though it is only used in walk(),
         * as it must be cloned and modified in the cloned version.
         */
        private Set<InstructionHandle> ihlvisited;
        /**
         * The instruction handles traversed.
         */
        private List<InstructionHandle> ihl;
        /**
         * The list of possible paths in the instruction list.
         */
        private List<InstructionPath> parallelPaths =
                new ArrayList<InstructionPath>();
        
        /**
         * 
         * @param il The instruction list to traverse.
         */
        public InstructionPath(InstructionList il){
            ihlvisited = new HashSet<InstructionHandle>();
            ihlvisited.add(il.getStart());
            ihl = new ArrayList<InstructionHandle>();
            ihl.add(il.getStart());
            parallelPaths.add(this);
            walk();
        }
        
        /**
         * 
         * @return The list of possible paths in the instruction list.
         */
        public List<InstructionPath> getParallelPaths(){
            return parallelPaths;
        }
        
        /**
         * 
         * @return The list of traversed instruction handles.
         */
        public List<InstructionHandle> getIhl(){
            return ihl;
        }
        
        /**
         * Traversed the instruction list.
         */
        private void walk(){
            InstructionHandle currentlyAt = ihl.get(ihl.size() - 1);
            while(true){
                InstructionHandle next = currentlyAt.getNext();
                // TODO mebbe add explicit jsr / ret support.
                if(currentlyAt instanceof BranchHandle){
                    InstructionHandle target =
                            ((BranchHandle)currentlyAt).getTarget();
                    // If instructions may be branched
                    if(currentlyAt.getInstruction() instanceof IfInstruction){
                        if(!ihl.contains(target)){
                            InstructionPath branch = this.clone();
                            parallelPaths.add(branch);
                            branch.ihl.add(target);
                            branch.ihlvisited.add(target);
                            branch.walk();
                        }
                        else{
                            // Everything from target to currentlyAt may be
                            // traversed several times. Thus, they are removed
                            // from ihl, but NOT from ihlvisited (this is the
                            // reason ihlvisited exists)
                            int segmentStart = ihl.indexOf(target);
                            int segmentLength = ihl.indexOf(currentlyAt) -
                                    segmentStart + 1;
                            for(int _ = 0; _ < segmentLength; _++)
                                ihl.remove(segmentStart);
                        }
                    }
                    else
                        next = target;
                }
                if(next == null)
                    break;
                if(ihl.contains(next)){
                    parallelPaths.remove(this);
                    ihl = new ArrayList<InstructionHandle>();
                    return;
                }
                ihl.add(next);
                ihlvisited.add(next);
                currentlyAt = next;
            }
        }
        
        @Override
        public InstructionPath clone(){
            Object o = null;
            try{
                o = super.clone();
            }
            catch(CloneNotSupportedException e){}
            return (InstructionPath)o;
        }
        
    }
    
    /**
     * Gets all top-level instructions in an instruction list.
     * That is to say, all instructions which are guaranteed to be executed
     * once, and only once. (Ignoring exceptions)
     * @param il The instruction list to analyse.
     * @return The list of handles which are top-level. A null in this list
     *     indicates a jump in the instruction list.
     */
    public static List<InstructionHandle> getTopLevelInstructions(
            InstructionList il){
        List<InstructionHandle> ret = new ArrayList<InstructionHandle>();
        
        List<InstructionPath> paths = new InstructionPath(il).
                getParallelPaths();
        // Top level instructions are those contained in all "paths" (although
        // technically not all paths are there)
        Set<InstructionHandle> retain = new HashSet<InstructionHandle>();
        if(paths.size() != 0)
            retain.addAll(paths.get(0).getIhl());
        for(int i = 1; i < paths.size(); i++){
            retain.retainAll(paths.get(i).getIhl());
        }
        
        boolean lastWasNull = false;
        for(InstructionHandle i = il.getStart(); i != null; i = i.getNext()){
            if(retain.contains(i)){
                ret.add(i);
                lastWasNull = false;
            }
            else{
                // When something was cut out of the list a single null is
                // added to indicate that.
                if(!lastWasNull)
                    ret.add(null);
                lastWasNull = true;
            }
        }
        
        return ret;
    }
    
    /**
     * Shifts local variables.
     * @param il The instruction list to shift.
     * @param from The local variable indexes to shift from.
     * @param to The target local variable indexes. (Must have the same
     *     length as from)
     */
    public static void shiftLocalVariables(InstructionList il, int[] from,
            int[] to){
        
        for(InstructionHandle ih : il.getInstructionHandles()){
            if(ih.getInstruction() instanceof LocalVariableInstruction){
                LocalVariableInstruction lvi =
                        (LocalVariableInstruction)ih.getInstruction();
                for(int i = 0; i < from.length; i++){
                    if(lvi.getIndex() == from[i]){
                        lvi.setIndex(to[i]);
                        break;
                    }
                }
            }
        }
        
    }
    
}
